rm(list = ls())
library(tidyverse)
library(broom)
library(haven)
library(polycor)
library(wCorr)
library(quanteda)
library(quanteda.textplots)
require(glmnet)
require(quanteda.textstats)
library(ggcorrplot)
library(tidymodels)
set.seed(221186)

# Weighted correlation that drops observations with a missing value in either
# `x` or `y` before delegating to wCorr::weightedCorr().
#
# Args:
#   x, y:    vectors of equal length holding the two variables.
#   weights: vector of observation weights, same length as `x`.
#   ...:     forwarded unchanged to weightedCorr() (e.g. `method`).
#
# Returns: whatever weightedCorr() returns for the complete pairs.
weightedCorr_na <- function(x, y, weights, ...) {

  # A logical keep-mask is used instead of negative indices: with no missing
  # values the index vector would be integer(0), and `x[-integer(0)]` selects
  # *nothing*, so complete data would reach weightedCorr() as empty vectors.
  complete <- !(is.na(x) | is.na(y))

  weightedCorr(x[complete], y[complete], weights = weights[complete], ...)
}
# Weighted mean absolute deviation of `x` around its (unweighted) mean.
# Missing values in `x` are ignored both when computing the centre and when
# averaging the deviations.
mae <- function(x, weights) {
  centre <- mean(x, na.rm = TRUE)
  abs_dev <- abs(x - centre)
  weighted.mean(abs_dev, w = weights, na.rm = TRUE)
}

# Number of bootstrap resamples drawn wherever intervals are computed below.
n_boot_reps <- 500
# Number of re-drawn treatment assignments used in the randomization-inference
# loop.
n_ri_samples <- 500

# NOTE(review): `just_panel`, used throughout the rest of the script, is not
# created here; presumably it is one of the objects restored by this load() --
# confirm.
load("../working/survey.Rdata")

## Treatment effects for stability

# Estimate, separately for each value of `treat`, the weighted polychoric
# correlation between each issue item and its `_w2` counterpart (presumably the
# second survey wave), then take the Treatment - Control difference.
#
# Args:
#   my_data:   data frame with a `treat` column taking the values "Treatment"
#              and "Control", a `weight` column, the six issue items, their
#              `_w2` counterparts and, when `threshold` is used,
#              `duration_trans_text`.
#   threshold: optional cut-off; Treatment rows whose `duration_trans_text` is
#              below it have all twelve item responses set to NA before
#              estimation. NULL (default) leaves the data untouched.
#
# Returns: a tibble with one row per item plus a "Mean" row (simple average of
#   the six correlations) and columns `name`, one column per arm, and `diff`.
#   Prints a "." on every call as a progress marker.
get_effects <- function(my_data, threshold = NULL) {

  issues <- c("trans_rights", "offensive_speech", "minimum_wage",
              "zero_hours", "unemployment_support", "high_tax")

  if (!is.null(threshold)) {
    # The flag depends only on `treat` and the duration, neither of which is
    # modified below, so it is computed once and reused for every column.
    too_fast <- my_data$treat == "Treatment" &
      my_data$duration_trans_text < threshold

    for (column in c(issues, paste0(issues, "_w2"))) {
      my_data[[column]][too_fast] <- NA
    }
  }

  cat(".")

  # Shorthand for the correlation used for every item.
  polychoric <- function(first, second, w) {
    weightedCorr_na(first, second, weights = w, method = "Polychoric")
  }

  by_arm <- my_data %>%
    group_by(treat) %>%
    summarise(
      MinimumWage = polychoric(minimum_wage, minimum_wage_w2, weight),
      Tax = polychoric(high_tax, high_tax_w2, weight),
      ZeroHours = polychoric(zero_hours, zero_hours_w2, weight),
      Unemployment = polychoric(unemployment_support, unemployment_support_w2, weight),
      Transgender = polychoric(trans_rights, trans_rights_w2, weight),
      Offensive = polychoric(offensive_speech, offensive_speech_w2, weight)
    )

  with_mean <- mutate(
    by_arm,
    Mean = (MinimumWage + Tax + ZeroHours + Unemployment + Transgender + Offensive) / 6
  )

  # Reshape so that items are rows and arms are columns.
  long <- pivot_longer(with_mean, cols = -treat)
  wide <- pivot_wider(long, names_from = treat)

  mutate(wide, diff = Treatment - Control)
}


# Estimate main effects

# Point estimates on the full panel.
effects_est <- get_effects(just_panel)

# Bootstrap: resample rows of the panel with replacement and re-estimate on
# each resample.
boot_out <- tibble(
  fx = lapply(seq_len(n_boot_reps), function(rep_id) {
    resampled_rows <- sample(seq_len(nrow(just_panel)), replace = TRUE)
    get_effects(just_panel[resampled_rows, ])
  })
)
boot_out$id <- seq_len(nrow(boot_out))

# 2.5% / 97.5% bootstrap quantiles for the difference and for each arm.
effect_ints <- boot_out %>%
  unnest(fx) %>%
  group_by(name) %>%
  summarise(
    high_diff = quantile(diff, 0.975),
    low_diff = quantile(diff, 0.025),

    high_control = quantile(Control, 0.975),
    low_control = quantile(Control, 0.025),

    high_treat = quantile(Treatment, 0.975),
    low_treat = quantile(Treatment, 0.025)
  )

# Combine estimates and intervals, switch to display labels, and copy the
# "Mean" row's estimate/interval onto every item row before dropping it.
stability_effects <- effects_est %>%
  full_join(effect_ints, by = "name") %>%
  mutate(
    name = gsub("Offensive", "Offensive Speech", name),
    name = gsub("Unemployment", "Unemployment Support", name),
    name = gsub("MinimumWage", "Minimum Wage", name),
    name = gsub("Tax", "High Tax", name),
    name = gsub("Transgender", "Transgender Rights", name),
    name = gsub("ZeroHours", "Zero Hours Contracts", name),
    av_est = unique(diff[name == "Mean"]),
    av_hi = unique(high_diff[name == "Mean"]),
    av_lo = unique(low_diff[name == "Mean"])
  ) %>%
  filter(name != "Mean")

# Point-range plot of the per-item differences (items ordered by estimate),
# with the average difference and its interval drawn as vertical lines over a
# shaded band, and a dashed reference line at zero.
ggplot(
  data = mutate(stability_effects, name = factor(name, levels = name[order(diff)])),
  mapping = aes(x = diff, xmin = low_diff, xmax = high_diff, y = name)
) +
  geom_pointrange() +
  geom_vline(xintercept = 0, linetype = 2) +
  geom_rect(
    aes(xmin = av_lo[1], xmax = av_hi[1], ymin = -Inf, ymax = Inf),
    fill = alpha("black", .04)
  ) +
  geom_vline(aes(xintercept = av_est), col = alpha("black", 1)) +
  geom_vline(aes(xintercept = av_lo), col = alpha("black", 1)) +
  geom_vline(aes(xintercept = av_hi), col = alpha("black", 1)) +
  theme_bw() +
  # Alternative axis label: expression(rho["i,j"]^"D=1" - rho["i,j"]^"D=0")
  labs(x = "Effect of reason-giving on stability", y = "")
# Optional tweaks, currently disabled:
#   scale_x_continuous(breaks = round(seq(-0.5, 0.5, .1),2), limits = c(-.35, .55))
#   theme(panel.background = element_rect(fill = "transparent"),
#         plot.background = element_rect(fill = "transparent", color = NA))

# ggsave() writes the most recently created plot.
ggsave("../out/outcomes/stability.pdf", width = 5, height = 3)
ggsave("../out/outcomes/stability.png", dpi = 600, width = 5, height = 3)

# Over-time correlations by arm: reshape estimate/low/high into one row per
# item and arm, then draw dodged point-ranges coloured by arm.
ggplot(
  data = stability_effects %>%
    select(name,
           est_control = Control,
           est_treat = Treatment,
           high_treat, high_control, low_treat, low_control) %>%
    # Column names are "<stat>_<arm>": the part before "_" becomes the value
    # column (est / high / low), the part after it goes into `Var`.
    pivot_longer(cols = -c(name),
                 names_to = c(".value", "Var"),
                 names_sep = "_") %>%
    mutate(Var = case_when(Var == "control" ~ "Control",
                           Var == "treat" ~ "Treatment")),
  mapping = aes(x = est, xmin = low, xmax = high, y = name, col = Var)
) +
  geom_pointrange(position = position_dodge(.5)) +
  theme_bw() +
  # Alternative axis label: expression(rho["i,j"]^"D=1" - rho["i,j"]^"D=0")
  labs(x = "Over-time correlation", y = "") +
  # Optional: scale_x_continuous(breaks = round(seq(-0.5, 0.5, .1),2), limits = c(-.35, .55))
  theme(panel.background = element_rect(fill = "transparent"),
        plot.background = element_rect(fill = "transparent", color = NA),
        legend.position = "bottom") +
  scale_color_manual("", values = c("black", "gray")) +
  scale_fill_manual("", values = c("black", "gray"))

ggsave("../out/outcomes/stability_levels.pdf", width = 5, height = 3)

# Estimate main effects (attrition weighted)

# Copy of the panel whose weight is multiplied by `weight_attrition_panel`.
just_panel_attr <- just_panel
just_panel_attr$weight <- with(just_panel_attr, weight * weight_attrition_panel)

effects_est_attr <- get_effects(just_panel_attr)

# Bootstrap on the re-weighted panel: resample rows with replacement and
# re-estimate on each resample.
boot_out_attr <- tibble(
  fx = lapply(seq_len(n_boot_reps), function(rep_id) {
    resampled_rows <- sample(seq_len(nrow(just_panel_attr)), replace = TRUE)
    get_effects(just_panel_attr[resampled_rows, ])
  })
)
boot_out_attr$id <- seq_len(nrow(boot_out_attr))

# 2.5% / 97.5% bootstrap quantiles for the difference and for each arm.
effect_ints_attr <- boot_out_attr %>%
  unnest(fx) %>%
  group_by(name) %>%
  summarise(
    high_diff = quantile(diff, 0.975),
    low_diff = quantile(diff, 0.025),

    high_control = quantile(Control, 0.975),
    low_control = quantile(Control, 0.025),

    high_treat = quantile(Treatment, 0.975),
    low_treat = quantile(Treatment, 0.025)
  )

# Combine estimates and intervals, switch to display labels, and copy the
# "Mean" row's estimate/interval onto every item row before dropping it.
stability_effects_attr <- effects_est_attr %>%
  full_join(effect_ints_attr, by = "name") %>%
  mutate(
    name = gsub("Offensive", "Offensive Speech", name),
    name = gsub("Unemployment", "Unemployment Support", name),
    name = gsub("MinimumWage", "Minimum Wage", name),
    name = gsub("Tax", "High Tax", name),
    name = gsub("Transgender", "Transgender Rights", name),
    name = gsub("ZeroHours", "Zero Hours Contracts", name),
    av_est = unique(diff[name == "Mean"]),
    av_hi = unique(high_diff[name == "Mean"]),
    av_lo = unique(low_diff[name == "Mean"])
  ) %>%
  filter(name != "Mean")

# Same point-range layout as the main stability plot, drawn from the
# attrition-weighted estimates and with transparent backgrounds.
ggplot(
  data = mutate(stability_effects_attr, name = factor(name, levels = name[order(diff)])),
  mapping = aes(x = diff, xmin = low_diff, xmax = high_diff, y = name)
) +
  geom_pointrange() +
  geom_vline(xintercept = 0, linetype = 2) +
  geom_rect(
    aes(xmin = av_lo[1], xmax = av_hi[1], ymin = -Inf, ymax = Inf),
    fill = alpha("black", .04)
  ) +
  geom_vline(aes(xintercept = av_est), col = alpha("black", 1)) +
  geom_vline(aes(xintercept = av_lo), col = alpha("black", 1)) +
  geom_vline(aes(xintercept = av_hi), col = alpha("black", 1)) +
  theme_bw() +
  # Alternative axis label: expression(rho["i,j"]^"D=1" - rho["i,j"]^"D=0")
  labs(x = "Effects of reason-giving on stability", y = "") +
  # Optional: scale_x_continuous(breaks = round(seq(-0.5, 0.5, .1),2), limits = c(-.35, .55))
  theme(panel.background = element_rect(fill = "transparent"),
        plot.background = element_rect(fill = "transparent", color = NA))

ggsave("../out/outcomes/stability_attrition.pdf", width = 5, height = 3)


## Randomization inference

# Attach the bootstrap intervals to the point estimates so both are saved
# together below.
effects_est <- full_join(effects_est, effect_ints, by = "name")

# Keep the realised assignment so every draw resamples from the same labels.
just_panel$treat_original <- just_panel$treat

# One row per draw, one column per row of `effects_est` (six items + "Mean").
# The layout is read off the estimates already in hand rather than re-running
# get_effects() twice just to obtain its dimensions and names.
ri_out <- matrix(
  NA_real_,
  nrow = n_ri_samples,
  ncol = nrow(effects_est),
  dimnames = list(NULL, effects_est$name)
)

# NOTE(review): labels are drawn WITH replacement, so the number of treated
# units varies across draws. If assignment was by complete randomization a
# permutation (replace = FALSE) would match the design -- confirm.
for (i in seq_len(n_ri_samples)) {
  ri_out[i, ] <- just_panel %>%
    mutate(treat = sample(treat_original, n(), replace = TRUE)) %>%
    get_effects() %>%
    .$diff
}

# Two-sided p-value: share of draws whose absolute difference is at least as
# large as the absolute observed difference. Both sides are taken in absolute
# value; comparing the *signed* estimate against |draws| would return p = 1
# for every negative estimate regardless of its size.
effects_est$p_value <- vapply(
  seq_along(effects_est$diff),
  function(v) mean(abs(ri_out[, v]) >= abs(effects_est$diff[v])),
  numeric(1)
)

save(stability_effects, effects_est, file = "../working/effect_estimates_stability.Rdata")
# Reloading straight after saving restores the same two objects; kept so the
# script can be resumed from this line in an interactive session.
load(file = "../working/effect_estimates_stability.Rdata")

# Heterogeneous effects

# Bootstrap summary of the Treatment - Control differences for one subset of
# the data.
#
# Args:
#   my_data: data frame accepted by get_effects().
#   wave:    not used by the function body; kept so existing calls keep working.
#
# Returns: a tibble with one row per item and columns
#   name                            item name as produced by get_effects()
#   Diff_mean, Diff_low, Diff_high  bootstrap mean and 2.5% / 97.5% quantiles
#                                   of the "Mean" row, repeated on every row
#   high, low                       97.5% / 2.5% bootstrap quantiles per item
#   est                             bootstrap mean of the item's difference
#
# NOTE(review): `est` is the mean over bootstrap draws, not the estimate on
# `my_data` itself. The original code computed get_effects(my_data) first but
# overwrote the result before using it; that dead (and expensive) call has been
# removed without changing any returned value. Confirm that the bootstrap mean
# is the intended point estimate.
get_intervals <- function(my_data, wave) {

  boot_draws <- lapply(seq_len(n_boot_reps), function(rep_id) {
    resampled_rows <- sample(seq_len(nrow(my_data)), replace = TRUE)
    get_effects(my_data[resampled_rows, ])
  })

  ests <- tibble(fx = boot_draws) %>%
    unnest(fx) %>%
    group_by(name) %>%
    summarise(mean = mean(diff),
              high = quantile(diff, 0.975),
              low = quantile(diff, 0.025))

  # Copy the across-item average onto every row, then drop its own row.
  is_average <- ests$name == "Mean"
  ests$Diff_low <- ests$low[is_average]
  ests$Diff_high <- ests$high[is_average]
  ests$Diff_mean <- ests$mean[is_average]
  ests <- ests[ests$name != "Mean", ]

  ests$est <- ests$mean
  ests$mean <- NULL
  ests %>% select(name, Diff_mean, Diff_low, Diff_high, high, low, est)

}

# Run get_intervals() separately within each level of one column of
# `just_panel`.
#
# Args:
#   var:    name of the subgroup column, as a string.
#   groups: levels to estimate for; defaults to the sorted distinct non-missing
#           values of the column.
#
# Returns: a list of get_intervals() results named by level.
#
# Rows are selected through which() so that respondents with a missing value
# on `var` are dropped. A bare logical comparison yields NA for those
# respondents, and indexing rows with NA inserts an all-NA row into every
# subgroup.
subgroup_intervals <- function(var, groups = sort(unique(just_panel[[var]]))) {
  out <- lapply(groups, function(g) {
    members <- which(just_panel[[var]] == g)
    get_intervals(just_panel[members, ], 1)
  })
  names(out) <- groups
  out
}

stability_effects_attention <- subgroup_intervals("attention_cat2", c("High", "Low"))
stability_effects_gender <- subgroup_intervals("gender", c("Male", "Female"))
stability_effects_age <- subgroup_intervals("age_cat")
stability_effects_educ <- subgroup_intervals("education.x")
stability_effects_eu_vote <- subgroup_intervals("eu_vote")
stability_effects_vote19 <- subgroup_intervals("vote19_simple")

stability_het_effects <- list(attention = stability_effects_attention,
                              gender = stability_effects_gender,
                              age = stability_effects_age,
                              education = stability_effects_educ,
                              eu_vote = stability_effects_eu_vote,
                              vote19 = stability_effects_vote19)



save(stability_het_effects, file = "../working/het_effects_stability.Rdata")

# Effects by treatment duration

# Point estimates after blanking the responses of Treatment rows whose
# `duration_trans_text` is below 30 (see get_effects()).
effects_est_thresh <- get_effects(just_panel, threshold = 30)

# Bootstrap with the same threshold applied inside every resample.
boot_out_thresh <- tibble(
  fx = lapply(seq_len(n_boot_reps), function(rep_id) {
    resampled_rows <- sample(seq_len(nrow(just_panel)), replace = TRUE)
    get_effects(just_panel[resampled_rows, ], threshold = 30)
  })
)
boot_out_thresh$id <- seq_len(nrow(boot_out_thresh))

# 2.5% / 97.5% bootstrap quantiles for the difference and for each arm.
effect_ints_thresh <- boot_out_thresh %>%
  unnest(fx) %>%
  group_by(name) %>%
  summarise(
    high_diff = quantile(diff, 0.975),
    low_diff = quantile(diff, 0.025),

    high_control = quantile(Control, 0.975),
    low_control = quantile(Control, 0.025),

    high_treat = quantile(Treatment, 0.975),
    low_treat = quantile(Treatment, 0.025)
  )

# Combine estimates and intervals, switch to display labels, and copy the
# "Mean" row's estimate/interval onto every item row before dropping it.
stability_effects_thresh <- effects_est_thresh %>%
  full_join(effect_ints_thresh, by = "name") %>%
  mutate(
    name = gsub("Offensive", "Offensive Speech", name),
    name = gsub("Unemployment", "Unemployment Support", name),
    name = gsub("MinimumWage", "Minimum Wage", name),
    name = gsub("Tax", "High Tax", name),
    name = gsub("Transgender", "Transgender Rights", name),
    name = gsub("ZeroHours", "Zero Hours Contracts", name),
    av_est = unique(diff[name == "Mean"]),
    av_hi = unique(high_diff[name == "Mean"]),
    av_lo = unique(low_diff[name == "Mean"])
  ) %>%
  filter(name != "Mean")


save(stability_effects_thresh, file = "../working/dur_effects_stability.Rdata")
